import os
import pickle
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision import models
import numpy as np
import pandas as pd
from CIFAR10_GAN import *
from TV_estimation import *

# Working directory for this run: the relative paths used below ('./data', 'GAN_*.pth',
# the results CSV) are resolved against it. Set DISCRIMINATIVE_APPROACH_DIR to override
# the default instead of editing this file.
new_path = os.environ.get("DISCRIMINATIVE_APPROACH_DIR", "/users/eval/discriminative_approach")
os.chdir(new_path)
current_path = os.getcwd()
print("Current Path:", current_path)

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


##-----------------data loading-----------------##

##--------- real data ---------##

# Dataset-level transform: image -> tensor, then per-channel (x - 0.5) / 0.5.
# It is attached to the CIFAR10 datasets below, but the tensors built from the raw
# `.data` arrays do not go through it.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def _cifar10_as_tensors(is_train):
    """Download (if needed) one CIFAR-10 split and return ``(dataset, images, labels)``.

    ``images`` is the dataset's raw ``.data`` array as a tensor, permuted with
    ``(0, 3, 1, 2)`` (channels moved to dim 1) and placed on ``device``; ``labels`` is
    ``.targets`` as a tensor on ``device``.
    """
    dataset = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)
    images = torch.tensor(dataset.data).permute(0, 3, 1, 2).to(device)
    labels = torch.tensor(dataset.targets).to(device)
    return dataset, images, labels


# NOTE(review): because `.data` is read directly, these images keep the dtype/range of the
# stored array (believed to be uint8 in 0-255 for torchvision's CIFAR10 — confirm), not the
# normalised range produced by `transform`. Verify this matches the GAN samples' scale.
trainset, train_data_real, train_data_real_y = _cifar10_as_tensors(True)
testset, test_data_real, test_data_real_y = _cifar10_as_tensors(False)

for _name, _tensor in (
    ("train_real_image", train_data_real),
    ("test_real_image", test_data_real),
    ("train_real_label", train_data_real_y),
    ("test_real_label", test_data_real_y),
):
    print(f"shape of {_name}: {_tensor.shape}")


##--------- generated data ---------##

# GAN samples are read from GAN_<G>_<split>.pth in the working directory; each file's
# loaded dict values are unpacked, in stored order, as (images, labels).
# NOTE(review): this assumes each saved dict holds exactly two entries, images first and
# labels second — confirm against the script that writes these files.


def _load_gan_samples(G, split):
    """Return the ``(images, labels)`` pair stored in ``GAN_{G}_{split}.pth``."""
    images, labels = torch.load(f'GAN_{G}_{split}.pth').values()
    return images, labels


Train_100_image, Train_100_y = _load_gan_samples(100, 'train')
Train_300_image, Train_300_y = _load_gan_samples(300, 'train')
Train_500_image, Train_500_y = _load_gan_samples(500, 'train')

Test_100_image, Test_100_y = _load_gan_samples(100, 'test')
Test_300_image, Test_300_y = _load_gan_samples(300, 'test')
Test_500_image, Test_500_y = _load_gan_samples(500, 'test')


# preview: for each split, print every image shape and then every label shape
_preview = {
    'Train': {100: (Train_100_image, Train_100_y),
              300: (Train_300_image, Train_300_y),
              500: (Train_500_image, Train_500_y)},
    'Test': {100: (Test_100_image, Test_100_y),
             300: (Test_300_image, Test_300_y),
             500: (Test_500_image, Test_500_y)},
}
for _split, _by_G in _preview.items():
    for _G, (_images, _labels) in _by_G.items():
        print(f"shape of {_split}_{_G}_image: {_images.shape}")
    for _G, (_images, _labels) in _by_G.items():
        print(f"shape of {_split}_{_G}_image_label: {_labels.shape}")


##-----------------ResNet-18 embedding-----------------##

# Load an ImageNet-pretrained ResNet-18. torchvision >= 0.13 provides the `weights=` API
# and deprecates `pretrained=True`, documenting it as equivalent to IMAGENET1K_V1.
# Older releases have no ResNet18_Weights enum, so fall back to the legacy flag there.
_resnet18_weights = getattr(models, "ResNet18_Weights", None)
if _resnet18_weights is not None:
    resnet18 = models.resnet18(weights=_resnet18_weights.IMAGENET1K_V1)
else:
    resnet18 = models.resnet18(pretrained=True)

def change_layers(model: torch.nn.Module, out_features: int = 50) -> torch.nn.Module:
    """Replace ``model.fc`` with a new ``Linear(512, out_features)`` layer (with bias).

    The new layer is freshly initialised by ``torch.nn.Linear``, i.e. it is not trained
    here. The model is modified in place.

    Args:
        model: network whose ``fc`` attribute is overwritten (a ResNet-18 in this script).
        out_features: output width of the new layer; defaults to 50.

    Returns:
        The same ``model`` object, so the call can be chained (e.g. with ``.to(device)``).
    """
    # Reference torch.nn explicitly instead of relying on `nn` leaking in via a star import.
    model.fc = torch.nn.Linear(512, out_features, bias=True)
    return model

# Embedding network: the pretrained ResNet-18 with its `fc` replaced, placed on `device`.
# eval() is required for feature extraction. In the default train mode, BatchNorm
# normalises with per-batch statistics and updates its running stats, so an image's
# embedding would depend on which other (real or synthetic) images share its batch.
# torch.no_grad() alone does not change this.
model = change_layers(resnet18).to(device).eval()

# Per-batch preprocessing applied right before the model: resize to 224x224, then
# normalise with ImageNet channel statistics.
# NOTE(review): ImageNet mean/std are meant for inputs scaled to [0, 1]; confirm the real
# (raw CIFAR `.data`) and GAN tensors fed through this are on that scale.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One row per (class i, GAN setting G): [i, G, DisE, KDE, PE].
All_error = []

# Synthetic (train images, train labels, test images, test labels) keyed by GAN setting G.
synthetic_by_G = {
    100: (Train_100_image, Train_100_y, Test_100_image, Test_100_y),
    300: (Train_300_image, Train_300_y, Test_300_image, Test_300_y),
    500: (Train_500_image, Train_500_y, Test_500_image, Test_500_y),
}

batch_size = 200


def _embed(images, batch_size=batch_size):
    """Return ``model(transform(chunk))`` for ``batch_size``-row chunks of ``images``,
    concatenated along dim 0 and computed without gradient tracking."""
    with torch.no_grad():
        outputs = [model(transform(chunk)) for chunk in torch.split(images, batch_size)]
    return torch.cat(outputs, dim=0)


for i in range(10):
    # Real images of class i do not depend on G, so select them once per class.
    Train_Real_X = train_data_real[train_data_real_y == i].to(device)
    Test_Real_X = test_data_real[test_data_real_y == i].to(device)

    for G in [100, 300, 500]:
        train_syn, train_syn_y, test_syn, test_syn_y = synthetic_by_G[G]
        Train_Syn_X = train_syn[train_syn_y == i].to(device)
        Test_Syn_X = test_syn[test_syn_y == i].to(device)

        # Balance the two classes by truncating real and synthetic to the smaller count.
        Train_size = min(Train_Real_X.shape[0], Train_Syn_X.shape[0])
        Test_size = min(Test_Real_X.shape[0], Test_Syn_X.shape[0])

        # Real rows first, synthetic rows second; the labels below follow this order.
        X_train = torch.cat((Train_Real_X[:Train_size], Train_Syn_X[:Train_size]), dim=0)
        X_test = torch.cat((Test_Real_X[:Test_size], Test_Syn_X[:Test_size]), dim=0)

        # Move embeddings to numpy once and reuse them for all three estimators.
        x_train = _embed(X_train).cpu().numpy()
        x_test = _embed(X_test).cpu().numpy()

        # Label 1 = real, 0 = synthetic.
        y_train = np.concatenate([np.ones(Train_size), np.zeros(Train_size)])
        y_test = np.concatenate([np.ones(Test_size), np.zeros(Test_size)])
        x_real, x_syn = x_train[:Train_size], x_train[Train_size:]

        # Sample mean and covariance of each group's embeddings (features as columns).
        mu_1_bar, mu_2_bar = x_real.mean(axis=0), x_syn.mean(axis=0)
        Sigma_1_bar, Sigma_2_bar = np.cov(x_real, rowvar=False), np.cov(x_syn, rowvar=False)

        DisE = Dist_TV(x_train, x_test, y_train, y_test)
        KDE = KDE_TV(x_real, x_syn)
        PE = MC_TV_Baseline(mu_1_bar, Sigma_1_bar, mu_2_bar, Sigma_2_bar)
        All_error.append([i, G, DisE, KDE, PE])
        print(i, G)


# Name the columns before writing, so the CSV header reads dig,G,DisE,KDE,PE rather
# than 0..4. Passing them to the constructor also works when All_error is empty.
Result = pd.DataFrame(All_error, columns=['dig', 'G', 'DisE', 'KDE', 'PE'])
Result.to_csv('Resnet18_50')

# Mean and standard deviation of every column across classes, per GAN setting G.
print(Result.groupby(['G']).mean())
print(Result.groupby(['G']).std())
